/*
    Copyright 2006-2012 Patrik Jonsson, sunrise@familjenjonsson.org

    This file is part of Sunrise.

    Sunrise is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 3 of the License, or
    (at your option) any later version.

    Sunrise is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with Sunrise.  If not, see <http://www.gnu.org/licenses/>.

*/

/// \file
/// FITS I/O routines for the grid classes.


#ifndef __grid_fits__
#define __grid_fits__

#include "grid.h"
#include "CCfits/CCfits"
#include "blitz-fits.h"
#include "loadbalance.h"
#include "boost/foreach.hpp"

/** Writes the grid structure to a new binary-table HDU named \p hdu
    in \p file.

    Header keywords:
    - LENGTHUNIT: the length unit.
    - MINX..MAXZ: the bounding box, annotated with "[unit]".
    - subdivtp: set to "OCTREE".

    Table columns:
    - "structure" (Boolean): one entry per cell, visited depth-first,
      that is true when the cell is subdivided.
    - When \p save_codes is set, also "code" (the 64-bit cell code),
      "qpos" (the quantized cell position, 3 elements) and "level"
      (the cell level).

    See the makegrid-readme for more details. */
template<typename cell_data_type/*, typename grid_type*/> 
void mcrx::adaptive_grid<cell_data_type/*,grid_type*/>:: 
save_structure (CCfits::FITS& file, const std::string&hdu,
		const std::string& length_unit,
		bool save_codes) const
{
  using namespace CCfits;
  std::cout << "saving grid structure in HDU " << hdu << std::endl; 

  // Unit annotation attached to the bounding-box keywords, e.g. "[kpc]";
  // left empty when no unit was supplied.
  std::string unit_note;
  if (!length_unit.empty())
    unit_note = "[" + length_unit + "]";

  Table* const table = file.addTable(hdu, 0);
  table->addKey("LENGTHUNIT", length_unit, "Length unit for grid");

  // Bounding box, emitted in the order MINX, MINY, MINZ, MAXX, MAXY, MAXZ.
  const char* const lower_keys[] = { "MINX", "MINY", "MINZ" };
  const char* const upper_keys[] = { "MAXX", "MAXY", "MAXZ" };
  for (int axis = 0; axis < 3; ++axis)
    table->addKey(lower_keys[axis], getmin () [axis], unit_note);
  for (int axis = 0; axis < 3; ++axis)
    table->addKey(upper_keys[axis], getmax () [axis], unit_note);

  table->addKey("subdivtp", "OCTREE", "Type of grid subdivision");

  // Gather the depth-first refinement flags and, when requested, the
  // per-cell codes and quantized positions.
  std::cout << "building structural information" << std::endl;

  std::vector<T_code> codes;
  std::vector<T_qpoint::T_pos> qpos;
  std::vector<bool> structure = get_structure (save_codes ? &codes : 0,
					       save_codes ? &qpos : 0);

  // All columns are declared before any data is written.
  table->addColumn(Tlogical, "structure", 1,
		   "True means cell is subdivided, listed in a depth-first fashion" );
  if (save_codes) {
    table->addColumn(Tlonglong, "code", 1, "The 64-bit cell code." );
    table->addColumn(Tuint, "qpos", 3, "The 32-bit quantized cell position." );
    table->addColumn(Tbyte, "level", 1, "The level of the cell.");
  }

  table->column("structure").write(structure, 1);

  if (!save_codes)
    return;

  // Each cell code is stored as two columns: the raw code value and its level.
  std::vector<uint64_t> raw_codes;
  std::vector<uint8_t> levels;
  for (std::size_t i = 0; i < codes.size(); ++i) {
    raw_codes.push_back(codes[i].code());
    levels.push_back(codes[i].level());
  }
  table->column("code").write(raw_codes, 1);
  write(table->column("qpos"), qpos, 1);
  table->column("level").write(levels, 1);
}


/** Constructs the grid from a grid-structure FITS HDU, such as the one
    written by save_structure.

    The steps are:
    1. Read the bounding box from the minx..maxz keywords and shift it
       by -\p translate_origin.
    2. Verify that the subdivtp keyword is "OCTREE".
    3. Read the Boolean "structure" column, which holds one refinement
       flag per cell in depth-first order.
    4. Compute the domain decomposition on the Hilbert curve at level
       \p decomp_level.
    5. Hand the structure to create_structure, which is defined
       elsewhere and builds the cells.

    \throws CCfits::FITS::OperationNotSupported if the HDU does not
    describe an octree. */
template < typename cell_data_type>
mcrx::adaptive_grid<cell_data_type>::
adaptive_grid (CCfits::ExtHDU& input, int decomp_level,
	       const vec3d& translate_origin) : 
  subgrid_block_ (0)
{
  using namespace CCfits;

  // Bounding-box corners.  They are read as float because the original
  // author noted that CCfits could not read doubles.
  // NOTE(review): save_structure writes these keywords from
  // getmin()/getmax().  If those are double precision, this read
  // truncates them.  Confirm whether the CCfits limitation still applies.
  const char* const lower_keys[] = { "minx", "miny", "minz" };
  const char* const upper_keys[] = { "maxx", "maxy", "maxz" };
  float lower[3], upper[3];
  for (int axis = 0; axis < 3; ++axis)
    input.readKey(lower_keys[axis], lower[axis]);
  min_ = vec3d (lower[0], lower[1], lower[2]);
  for (int axis = 0; axis < 3; ++axis)
    input.readKey(upper_keys[axis], upper[axis]);
  max_ = vec3d (upper[0], upper[1], upper[2]);
  min_ -= translate_origin;
  max_ -= translate_origin;

  // Only octree-structured HDUs can be loaded by this class.
  std::string subdivision;
  input.readKey("subdivtp", subdivision);
  if (subdivision != "OCTREE") {
    std::cerr << "Error: Octree class can't read grid structure for "
	      << subdivision << std::endl;
    throw CCfits::FITS::OperationNotSupported ("Wrong grid type");
  }

  // Depth-first refinement flags, one per cell.
  std::vector<bool> structure;
  Column& flags = input.column("structure");
  flags.read(structure, 1, flags.rows() );

  domain_decomposition(structure, decomp_level);
  create_structure(structure);
}

/** Advances the structure iterator past an entire octree subtree.

    A refined cell has eight child entries in the depth-first structure
    vector.  Each child that is itself refined is followed immediately
    by its own subtree, which is skipped recursively.

    \param refined  On entry, points at the first child entry of the
                    subtree.  On return, points just past the last entry
                    of the subtree.
    \param e        End of the structure vector.  It is used to detect a
                    truncated or malformed structure instead of reading
                    past the end. */
template<typename cell_data_type>
void
mcrx::adaptive_grid<cell_data_type>::
skip_suboctree(std::vector<bool>::const_iterator& refined,
	       const std::vector<bool>::const_iterator& e) const
{
  for (int i=0; i<8; ++i) {
    // A structure vector that ends inside a subtree would otherwise be
    // dereferenced past its end.
    assert(refined != e);
    if(*(refined++))
      skip_suboctree(refined, e);
  }
}

  
/** Creates the octree by a depth-first traversal driven by the
    structure vector.

    \p refined iterates over the depth-first refinement flags
    (true = subdivided), and \p e is the end of those flags.  On
    return, \p refined has been advanced past every flag consumed.
    Cells are subdivided by calling refine(\p placement), which is given
    a pointer into a preallocated memory block so that the grid stays
    compact in memory.  This presumably relies on placement new inside
    refine; confirm there.

    When a domain decomposition is present (domain_ non-empty), each
    cell's Hilbert code is looked up in domain_ to find the task(s) its
    extent covers:
    - A refined cell that overlaps this task is refined, and its task_
      is set to -1.
    - A refined cell outside this task has its subtree skipped in the
      structure vector.  Its task_ is set to the owning task, or to -1
      when it straddles several tasks.
    - A leaf cell's task_ is set the same way.

    Without a domain decomposition, every cell is attributed to this
    task.  Finally, the number of leaf cells owned by this task is
    printed. */
template<typename cell_data_type>
void
mcrx::adaptive_grid<cell_data_type>::
load_recursive (std::vector<bool>::const_iterator& refined,
		const std::vector<bool>::const_iterator& e,
		void*& placement)
{
  using namespace std;
  using namespace hilbert;
  const int task = mpi_rank();
  const int decomp_level=domain_.empty() ? 0 : domain_.back().level();

  T_cell_tracker c(this);
  int n=0;

  while(!c.at_end()) {

    assert (refined != e);

    // Default to "this task owns the cell".  This is the correct answer
    // when there is no domain decomposition, and it keeps the reads of
    // lowtask/hightask below well-defined in that case (they would
    // otherwise be uninitialized).
    bool inside=true;
    int lowtask=task, hightask=task;

    if(!domain_.empty()) {

      // To figure out which task the cell is in, we search the domain
      // vector for the first partition that is not <= the
      // low-extended cell code for the current cell.  With less_equal
      // as comparator this is equivalent to upper_bound.  Because we
      // do low-extension, this gives us the task of the start of the
      // cell.
      const vector<cell_code>::const_iterator tli = 
	lower_bound(domain_.begin(), domain_.end(), 
		    cell_code(c.code()).extend_low(decomp_level),
		    std::less_equal<cell_code>());
      lowtask = tli-domain_.begin()-1;

      // If the cell level is lower than the decomp level, the cell can
      // straddle domains.  In that case we also need the task of the
      // end of the cell, so we search for the first partition that is
      // not < the high-extended Hilbert code.
      if(c.code().level()<decomp_level) {
	const vector<cell_code>::const_iterator thi = 
	  lower_bound(domain_.begin(), domain_.end(), 
		      cell_code(c.code()).extend_high(decomp_level));
	hightask = thi-domain_.begin()-1;
      }
      else
	hightask = lowtask;
      inside = lowtask<=task && hightask>=task;
    }

    if (*(refined++)) {
      // The cell is refined.  What we do depends on whether it's in
      // our domain or not.
      if(inside) {
	c.cell()->task_ = -1;
	c.cell()->refine(placement);
	// By refining the cell, we know that the next call to dfs()
	// will descend.
      }
      else {
	// If there is a refinement in the structure that we don't do,
	// we need to spool the structure iterator past that subtree.
	skip_suboctree(refined, e);

	// We also need to assign a task to this cell, so we know where
	// to send the ray in case we end up there.  If the cell
	// encompasses several domains, we flag it (-1) for later.
	c->task_ = (lowtask==hightask) ? lowtask : -1;
      }
    }
    else {
      // Leaf cell: set its task field and count it if it is ours.
      assert(!inside || lowtask==hightask);
      c->task_ = (lowtask==hightask) ? lowtask : -1;
      if(c->task_==task) ++n;
    }

    // Step to the next cell in the depth-first traversal.
    c.dfs();
  }
  cout << "Task " << task << " created " << n << " leaf cells, should be " << domain_index().second << endl;
}

#endif
